
Fp8 e4m3_fnuz support for rocm #2588

Open · mht-sharma wants to merge 1 commit into main
Conversation

mht-sharma (Collaborator)

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

self.input_scale,
self.activation_scale_ub,
bias,
self.dtype,
)


class Fp8Linear(torch.nn.Module):
mht-sharma (Collaborator, Author)
Would it be cleaner to have a separate Fp8LinearRocm?

Member
Maybe, it depends a bit on how much conditional code we end up with. We did separate FP8 Marlin for this reason.
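
For illustration, a minimal sketch of the split being discussed, assuming a dedicated subclass: the base Fp8Linear stays free of SYSTEM == "rocm" branches and the ROCm variant owns the e4m3fnuz conversion. Apart from the Fp8Linear name, the structure and constructor arguments below are hypothetical, not the PR's code.

import torch


class Fp8Linear(torch.nn.Module):
    def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias=None) -> None:
        super().__init__()
        self.qweight = qweight
        self.scale = scale
        self.bias = bias


class Fp8LinearRocm(Fp8Linear):
    """ROCm variant: converts e4m3fn weights and scales to e4m3fnuz on load."""

    def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias=None) -> None:
        if qweight.dtype == torch.float8_e4m3fn:
            # Reinterpret the bytes as e4m3fnuz and double the scale so the
            # dequantized values stay the same (simplified; the PR's
            # normalize_e4m3fn_to_e4m3fnuz helper also remaps the -0/NaN encoding).
            qweight = qweight.view(torch.float8_e4m3fnuz)
            scale = scale * 2.0
        super().__init__(qweight, scale, bias)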

@@ -92,9 +123,17 @@ def get_weights(self, weights: "Weights", prefix: str):
            .reshape(-1)
            .expand(w.shape[0])
        )
        try:
            input_scale = weights.get_tensor(
Member
Weights also has _has_tensor; maybe we should make it public and use it here?

Member
Same for the try: [...] get_tensor below.
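
A minimal sketch of that suggestion, using a stub Weights class rather than TGI's real one; the has_tensor name and the input_scale key are illustrative. The point is replacing the try/except around get_tensor with an explicit presence check.

from typing import Dict, Optional

import torch


class Weights:
    def __init__(self, tensors: Dict[str, torch.Tensor]) -> None:
        self._tensors = tensors

    def _has_tensor(self, name: str) -> bool:
        return name in self._tensors

    # Public wrapper around the existing private lookup.
    def has_tensor(self, name: str) -> bool:
        return self._has_tensor(name)

    def get_tensor(self, name: str) -> torch.Tensor:
        return self._tensors[name]


def load_input_scale(weights: Weights, prefix: str) -> Optional[torch.Tensor]:
    # Explicit check instead of try/except around get_tensor.
    name = f"{prefix}.input_scale"
    if weights.has_tensor(name):
        return weights.get_tensor(name).reshape(-1)
    return None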

@@ -72,6 +99,10 @@ def fp8_quantize(
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()

    if SYSTEM == "rocm":
        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
Member
We should wire up scale at some point for CUDA as well.
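
For context, this is roughly what the e4m3fn → e4m3fnuz normalization does; a sketch modeled on the vLLM-style helper, so the PR's actual normalize_e4m3fn_to_e4m3fnuz may differ in signature and details.

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 0b10000000 encodes -0.0 in e4m3fn but NaN in e4m3fnuz,
    # so remap it to +0.0 before reinterpreting the bytes.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # For the same bit pattern an e4m3fnuz value is half the e4m3fn value,
    # so the scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale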

bias=self.bias,
)

if type(output) is tuple and len(output) == 2:
Member
Did this change between torch versions or is output for AMD different?

Suggested change:
-if type(output) is tuple and len(output) == 2:
+if isinstance(output, tuple) and len(output) == 2:
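
For what it's worth, torch._scaled_mm has returned an (output, amax) tuple in some releases and a bare tensor in others, so a version-tolerant call site can unpack conditionally. A sketch only, not the PR's code:

import torch


def scaled_mm(a: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Tensor:
    # Some torch releases return (output, amax); others return only the output.
    output = torch._scaled_mm(a, b, **kwargs)
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]
    return output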


@@ -62,7 +62,7 @@ def from_unquant(cls, weight, bias, dtype):
        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

    @classmethod
-   def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+   def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
Member
Type.

These arguments are getting a bit messy. It's easy to mix up a tensor and a float (which was already happening here?). Maybe we should switch these to keyword-only arguments so that call sites need to be explicit (plus type annotations).
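
A sketch of the keyword-only signature being suggested; the class name, annotations, and body are illustrative placeholders based on the hunk above, not the PR's final interface.

from typing import Optional

import torch


class Fp8MarlinLinear(torch.nn.Module):
    """Placeholder name standing in for the layer in the hunk above."""

    def __init__(self, *, qweight: torch.Tensor, scales: torch.Tensor,
                 bias: Optional[torch.Tensor]) -> None:
        super().__init__()
        self.qweight = qweight
        self.scales = scales
        self.bias = bias

    @classmethod
    def from_fp8(
        cls,
        *,
        weight: torch.Tensor,
        scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
        bias: Optional[torch.Tensor] = None,
        dtype: torch.dtype,
    ) -> "Fp8MarlinLinear":
        # Keyword-only arguments force call sites to spell out each value,
        # so a tensor scale can't silently land in the float upper-bound slot.
        del input_scale, scale_upper_bound  # unused by this layer, as in the hunk
        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

With this shape, a call site reads from_fp8(weight=w, scale=s, dtype=torch.float16, bias=None), and a misplaced positional argument becomes a TypeError instead of silently landing in the wrong slot.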
